import torch
import string
import os
from typing import List, Dict, Iterator, Tuple
import openai

from transformers import BertTokenizer, BertForMaskedLM
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from tenacity import retry, stop_after_attempt, wait_exponential
from transformers import logging

# `logging` here is `transformers.logging` (imported above), not the stdlib
# module: restrict the transformers library's log output to errors only.
logging.set_verbosity_error()

from attack.score import Text
from abc import ABC, abstractmethod
from typing import List, Dict

class Substitutor(ABC):
    """Abstract base for strategies that propose replacement words.

    Subclasses implement :meth:`suggestions`. The signature declared here
    mirrors the concrete substitutors in this module, which receive the text
    on every call instead of storing it at construction time.
    """

    def __init__(self, text: Text, k: int):
        """Store a default text and the number of suggestions per word.

        Args:
            text: Text used by :meth:`generate_substitutions` when the caller
                does not pass one explicitly.
            k: Number of suggestions requested for each word.
        """
        self.text = text
        self.k = k

    @abstractmethod
    def suggestions(self, text: Text, indx: int) -> List[str]:
        """Get k suggestions for the word at given index of ``text``"""
        pass

    def generate_substitutions(self, text: Text = None) -> Dict[int, List[str]]:
        """Generate substitutions for all words in the text.

        Args:
            text: Text to process. When omitted, the text stored by
                ``__init__`` is used.

        Returns:
            A dict mapping each position ``0 .. len(text) - 1`` to the list
            returned by :meth:`suggestions` for that position.

        Raises:
            ValueError: If no text is passed and none was stored on the
                instance (subclasses may skip ``Substitutor.__init__``).
        """
        if text is None:
            # getattr: subclasses that override __init__ without calling
            # super() never create the attribute at all.
            text = getattr(self, "text", None)
        if text is None:
            raise ValueError(
                "No text available: pass `text` or initialise the substitutor with one."
            )
        return {i: self.suggestions(text, i) for i in range(len(text))}

class ModernBertSubstitutor(Substitutor):
    """Suggest replacement words using a masked language model.

    The word at the requested index is replaced by the tokenizer's mask token
    and the model's top-``k`` predictions for that position are returned.
    Predictions are cached per ``(text, index)`` pair.
    """

    def __init__(self, k: int = 30, model_name: str = "answerdotai/ModernBERT-large", device: str = 'cuda:1'):
        """Configure the substitutor and load the model immediately.

        Args:
            k: Number of top predictions requested from the model.
            model_name: Identifier passed to ``from_pretrained``.
            device: Device the model is moved to after loading.
        """
        self.model_name = model_name
        self.masked_model = None
        self.masked_tokenizer = None
        self.k = k
        self.device = device
        self.text_cache = {}
        self.load_model()

    def load_model(self):
        """Load tokenizer and model once; later calls only report the state.

        The optional ``HF_TOKEN`` environment variable is forwarded to
        ``from_pretrained`` as the access token.
        """
        if self.masked_model is None or self.masked_tokenizer is None:
            token = os.environ.get('HF_TOKEN')
            print(f"Downloading and loading model {self.model_name}...")
            self.masked_tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=token)
            self.masked_model = AutoModelForMaskedLM.from_pretrained(self.model_name, token=token)
            self.masked_model.to(self.device)
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")

    def process_text(self, text: Text, indx: int) -> Tuple[str, int]:
        """Compute (or reuse) the predictions for ``text`` at ``indx``.

        Returns:
            The key under which the entry is stored in ``self.text_cache``.
        """
        # A tuple key keeps text and index separate. Concatenating them made
        # e.g. ("a1", 2) and ("a", 12) indistinguishable, and hashing the
        # result added a further (small) collision risk.
        text_id = (str(text), indx)
        if text_id not in self.text_cache:
            self.text_cache[text_id] = {
                'text': text,
                'predictions': self._get_mask_predictions(text, indx)
            }
        return text_id

    def suggestions(self, text: Text, indx: int) -> List[str]:
        """Return the cached/computed predictions minus blank or punctuation-only ones."""
        text_id = self.process_text(text, indx)
        candidates = self.text_cache[text_id]['predictions']
        return [c for c in candidates if not self._is_blank_or_punctuation(c)]

    @staticmethod
    def _is_blank_or_punctuation(token: str) -> bool:
        """Tell whether ``token`` is empty/whitespace or made only of punctuation.

        Surrounding whitespace is ignored for the check because decoded tokens
        may carry it; the token itself is not modified by the caller.
        """
        stripped = token.strip()
        return not stripped or all(ch in string.punctuation for ch in stripped)

    def _get_mask_predictions(self, text: Text, indx: int) -> List[str]:
        """Mask position ``indx`` of ``text`` and decode the top-k predicted tokens."""
        masked_text = text.replace(indx, self.masked_tokenizer.mask_token)
        top_ids = self._token_substitution_preds(
            self.masked_model, self.masked_tokenizer, masked_text, self.k
        )
        return [self.masked_tokenizer.decode([candidate], skip_special_tokens=True)
                for candidate in top_ids]

    @staticmethod
    def _token_substitution_preds(masked_model, masked_tokenizer, text, k):
        """Return the ids of the ``k`` highest-scoring tokens at the first mask.

        Raises:
            ValueError: If the tokenized input contains no mask token
                (raised by ``list.index``).
        """
        inputs = masked_tokenizer(str(text), return_tensors='pt').to(masked_model.device)
        # Pure inference: skip autograd bookkeeping. Sampling arguments such
        # as do_sample/temperature belong to generate(), not to a forward
        # pass, so none are passed here.
        with torch.no_grad():
            mlm_predictions = masked_model(**inputs)
        masked_index = inputs["input_ids"][0].tolist().index(masked_tokenizer.mask_token_id)
        top_k_ids = mlm_predictions.logits[0, masked_index].topk(k).indices.tolist()

        return top_k_ids

    def clear_cache(self):
        """Clear the text predictions cache"""
        self.text_cache = {}

class GPTSubstitutor(Substitutor):
    """Suggest replacement words by prompting an OpenAI chat model.

    Suggestions are cached per *word* (not per sentence): once a word has been
    queried, later requests for the same word reuse the first answer even if
    the surrounding sentence differs.
    """

    # Completion budget: never below the historical 50 tokens, and scaled with
    # k so that long lists are not cut off. The per-suggestion figure is a
    # heuristic (a short word plus its line break).
    _MIN_COMPLETION_TOKENS = 50
    _TOKENS_PER_SUGGESTION = 6

    def __init__(self, k: int, model_name: str = "gpt-4o-mini", device: str = 'None'):
        """Configure the substitutor and create the API client.

        Args:
            k: Maximum number of suggestions returned per word.
            model_name: Chat model passed to the completions endpoint.
            device: Accepted but not used by this class.
        """
        self.model_name = model_name
        self.k = k
        self.substitution_cache = {}  # Cache for storing predictions with format {word: [suggestions]}
        self._load_api_key()

    def _load_api_key(self):
        """Resolve the API key and instantiate the OpenAI client.

        Raises:
            RuntimeError: If the key is neither in ``OPENAI_API_KEY`` nor in a
                non-empty ``./openai_api_key.key`` file.
        """
        missing_key_message = (
            "OpenAI API key not found. "
            "Please set the OPENAI_API_KEY env var or create './openai_api_key.key'."
        )

        # 1) Try environment variable
        self.api_key = os.getenv("OPENAI_API_KEY")

        # 2) Fallback to file if not in env
        if not self.api_key:
            try:
                with open('./openai_api_key.key', 'r') as f:
                    self.api_key = f.read().strip()
            except FileNotFoundError as err:
                raise RuntimeError(missing_key_message) from err
            if not self.api_key:
                # An empty key file would otherwise only fail on the first request.
                raise RuntimeError(missing_key_message)

        # 3) Instantiate client
        self.client = openai.OpenAI(api_key=self.api_key)

    def _get_completion(self, prompt: str) -> str:
        """
        Get completion from OpenAI API.

        No retry is performed here: any error is printed and re-raised.
        Returns the stripped message text, or "" when the reply has no text
        content.
        """
        token_budget = max(
            self._MIN_COMPLETION_TOKENS,
            self._TOKENS_PER_SUGGESTION * self.k,
        )
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{
                    "role": "system",
                    "content": """You are a language expert specializing in contextual word substitutions. Your task is to provide alternative words that fit grammatically and semantically within a modified sentence while ensuring that the overall meaning remains as close as possible to the initial sentence.
                    When suggesting substitutions, ensure they are grammatically correct, contextually appropriate, and maintain the intent of the original sentence before modification. Avoid providing generic, out-of-context, or nonsensical words. Avoid word repetitions.
                    If you cannot provide the requested number of meaningful substitutions, suggest only those that truly preserve the sentence’s original intent."""
                }, {
                    "role": "user",
                    "content": prompt
                }],
                temperature=0,
                max_tokens=token_budget,
            )
            content = response.choices[0].message.content
            # content may be None (e.g. the reply carries no text).
            return (content or "").strip()
        except Exception as e:
            print(f"Error calling OpenAI API: {e}")
            raise

    @staticmethod
    def _parse_suggestions(response: str, word_to_replace: str, k: int) -> List[str]:
        """Turn a one-word-per-line reply into at most ``k`` clean suggestions.

        Drops blank lines, punctuation-only lines, the original word and
        repeated entries, keeping first-seen order.
        """
        seen = set()
        parsed = []
        for line in response.split('\n'):
            word = line.strip()
            if not word or word == word_to_replace or word in seen:
                continue
            if all(ch in string.punctuation for ch in word):
                continue
            seen.add(word)
            parsed.append(word)
        return parsed[:k]

    def suggestions(self, text: Text, indx: int) -> List[str]:
        """
        Get k suggestions for the word at the given index
        """
        # Check cache first
        word_to_replace = text.words[indx]
        if word_to_replace in self.substitution_cache:
            cached_suggestions = self.substitution_cache[word_to_replace]
            # Return only k suggestions even from cache
            return cached_suggestions[:self.k]

        # Create prompt for GPT
        context = str(text)

        # The sentence is wrapped in a balanced pair of quotes (the previous
        # template opened with a stray quote and never closed the second one).
        prompt = f"""Given the sentence: "{context}".

Your task is to provide up to {self.k} alternative words for '{word_to_replace}' that fit naturally within the modified sentence while ensuring that the overall meaning remains as close as possible to the initial sentence.

The substitutions must be grammatically correct, semantically appropriate, and must not alter the intended message. If you cannot provide {self.k} fully meaningful substitutions, list only those that truly preserve the sentence’s original intent.

Do not include generic, out-of-context, or nonsensical words. Avoid repeating words already present in the modified sentence.
Only provide the words, one per line, and do not include any explanations or punctuation.
        """

        # Get predictions from GPT
        response = self._get_completion(prompt)

        # Process the response
        suggestions = self._parse_suggestions(response, word_to_replace, self.k)

        # Cache the results
        self.substitution_cache[word_to_replace] = suggestions

        return suggestions


